import os
import pickle as pkl

import matplotlib.pyplot as plt
import numpy as np
# from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import f1_score, precision_score, recall_score
from sklearn.preprocessing import label_binarize

# Parent directory that contains the ``save_output`` folder.
# Fill this in before running the script.
path = ''

def cal_total_acc(p, l):
    """Return the overall top-1 accuracy of ``p`` against labels ``l``.

    Args:
        p: score matrix; the predicted class of row ``k`` is ``p[k].argmax()``.
        l: array of true labels, one per row of ``p``.

    Returns:
        Number of rows whose argmax equals the label, divided by ``l.shape[0]``.
    """
    predicted_idx = p.argmax(1)
    n_samples = l.shape[0]
    n_correct = (predicted_idx == l).sum()
    return n_correct / n_samples
        

def cal_auc(predict_list, label_list, classes):
    """Compute the one-vs-rest ROC AUC of each class plus the micro-averaged AUC.

    Args:
        predict_list: 2-D score/probability array. Column ``i`` is treated as
            the score for ``classes[i]`` (assumed column order; confirm
            against the model output).
        label_list: true labels, one per row of ``predict_list``.
        classes: ordered class labels, passed to ``label_binarize``.

    Returns:
        dict mapping each column index ``i`` to that class's AUC. It also has
        the key ``"micro"`` for the micro-averaged AUC, which pools every
        (sample, class) pair into one binary problem.
    """
    y_true_bin = label_binarize(label_list, classes=classes)
    if len(classes) == 2 and y_true_bin.shape[1] == 1:
        # With exactly two classes label_binarize returns a single column
        # marking classes[1]. Expand it to one indicator column per class so
        # column i lines up with predict_list[:, i], and so the micro ravel()
        # below has the same length as predict_list.ravel().
        y_true_bin = np.hstack([1 - y_true_bin, y_true_bin])
    n_classes = y_true_bin.shape[1]

    roc_auc = {}
    for i in range(n_classes):
        fpr_i, tpr_i, _ = roc_curve(y_true_bin[:, i], predict_list[:, i])
        roc_auc[i] = auc(fpr_i, tpr_i)

    # Micro-averaged AUC (not macro): all class/sample decisions pooled together.
    fpr_micro, tpr_micro, _ = roc_curve(y_true_bin.ravel(), predict_list.ravel())
    roc_auc["micro"] = auc(fpr_micro, tpr_micro)

    return roc_auc

def auc_roc_curve(predict_list, label_list, classes):
    """Print per-class ROC AUCs, plot the per-class ROC curves and show them.

    Args:
        predict_list: 2-D score/probability array. Column ``i`` is treated as
            the score for ``classes[i]`` (assumed column order; confirm
            against the model output).
        label_list: true labels, one per row of ``predict_list``.
        classes: ordered class labels, passed to ``label_binarize``.

    Returns:
        dict mapping each column index ``i`` to its AUC, plus ``"micro"`` for
        the micro-averaged AUC. These are the same values that are printed.
    """
    y_true_bin = label_binarize(label_list, classes=classes)
    if len(classes) == 2 and y_true_bin.shape[1] == 1:
        # For two classes label_binarize yields one column (marking classes[1]).
        # Expand it to one column per class so column i matches
        # predict_list[:, i] and the micro ravel() sizes agree.
        y_true_bin = np.hstack([1 - y_true_bin, y_true_bin])
    n_classes = y_true_bin.shape[1]
    fpr = {}
    tpr = {}
    roc_auc = {}

    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], predict_list[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # Micro-averaged AUC: all (sample, class) decisions pooled together.
    fpr["micro"], tpr["micro"], _ = roc_curve(y_true_bin.ravel(), predict_list.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    # Print each class's AUC.
    for i in range(n_classes):
        print(f"Class {i} AUC: {roc_auc[i]}")

    # Draw one ROC curve per class.
    plt.figure()
    # pyplot.get_cmap replaces plt.cm.get_cmap (matplotlib.cm.get_cmap), which
    # was removed in Matplotlib 3.9. 'tab20' has 20 distinct colours and is
    # resampled to n_classes entries.
    colors = plt.get_cmap('tab20', n_classes)
    for i in range(n_classes):
        plt.plot(fpr[i], tpr[i], color=colors(i), lw=2, label=f'Class {i} (AUC = {roc_auc[i]:0.2f})')

    # Diagonal reference line for a chance-level classifier.
    plt.plot([0, 1], [0, 1], color='gray', linestyle='--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend(loc="lower right")
    plt.show()

    # The pooled value computed above is a micro average, so label it as such.
    print(f"Micro average AUC: {roc_auc['micro']}")
    return roc_auc

def cal_class_acc(predict_list, label_list, classes):
    """Return the per-class accuracy (recall) for every label in ``classes``.

    For a class ``c``, the value is the fraction of samples whose true label
    is ``c`` that are also predicted as ``c``. The prediction is the argmax
    over the last axis of ``predict_list``.

    Returns:
        ndarray with one accuracy per entry of ``classes``, in order. If any
        class has no samples in ``label_list``, a message is printed and
        ``None`` is returned instead.
    """
    predicted = predict_list.argmax(-1)
    per_class = []
    for cls in classes:
        is_cls = label_list == cls
        n_samples = np.sum(is_cls)
        if n_samples == 0:
            # Accuracy is undefined without samples; report and give up.
            print(f'Lack of sample of {cls}')
            return None
        n_hit = np.sum(is_cls & (predicted == cls))
        per_class.append(n_hit / n_samples)
    return np.array(per_class)

def cal_class_acc_print(predict_list, label_list, classes, debug_class=12):
    """Compute per-class accuracy (recall) while printing diagnostics.

    Prints the distinct predicted labels. When the loop reaches
    ``debug_class``, it also prints that class's sample count and the indices
    where it is the true and the predicted label.

    Args:
        predict_list: score array; predictions are its argmax over the last axis.
        label_list: array of true labels.
        classes: class labels to evaluate, in output order.
        debug_class: class whose locations are dumped. Defaults to 12, the value
            that used to be hard-coded. Pass ``None`` to disable the dump.

    Returns:
        ndarray with one accuracy per entry of ``classes``, or ``None`` (after
        printing a message) if some class has no samples.
    """
    predict_label = predict_list.argmax(-1)
    print(np.unique(predict_label))
    accuracy_list = []
    for i in classes:
        total_num = np.sum(label_list == i)
        if total_num == 0:
            print(f'Lack of sample of {i}')
            return None
        true_predict = np.sum((predict_label == i) & (label_list == i))
        if debug_class is not None and i == debug_class:
            print(f'class {debug_class} num:', total_num)
            print(f'true {debug_class} location:', np.where(label_list == debug_class))
            print(f'predict {debug_class} location:', np.where(predict_label == debug_class))
        # Share of class-i samples that were predicted as class i.
        accuracy = true_predict / total_num
        accuracy_list.append(accuracy)
    return np.array(accuracy_list)


# Load the saved evaluation outputs.
# os.path.join keeps an empty `path` relative (./save_output/...). Plain string
# concatenation turned it into /save_output/... at the filesystem root.
# NOTE(review): pickle.load can execute arbitrary code; only load output.pkl
# files produced by your own runs, never ones from untrusted sources.
with open(os.path.join(path, 'save_output', 'output.pkl'), 'rb') as f:
    save_dict = pkl.load(f)

probability = save_dict['probability']
print('probability shape:', probability.shape)
label = save_dict['label']
classes = save_dict['classes']
# Top-1 predicted class per sample, shared by the sklearn metrics below.
predicted = probability.argmax(1)

print('*'*50)
# Test accuracy recorded during training: (1) compute the accuracy of each
# vessel, then (2) average those per-vessel accuracies.
test_acc = save_dict['acc']
print('test_acc:', test_acc)

# Total accuracy: all individual branches are extracted from the vessel
# structures and pooled, and the classification accuracy is computed once
# across every branch.
total_acc = cal_total_acc(probability, label)
print('total_acc:', total_acc)
print('*'*50)

print('*'*50)
# Per-class one-vs-rest AUC; "AUC-ROC" is their unweighted (macro) mean.
auc_value = cal_auc(probability, label, classes)
mean_auc = np.array([auc_value[i] for i in range(len(classes))])
for i, class_auc in enumerate(mean_auc):
    print(f'Class {i} AUC: {class_auc}')
print('AUC-ROC:', np.mean(mean_auc))
print('*'*50)

print('*'*50)
# Per-class accuracy (recall). cal_class_acc returns None after reporting a
# class with no samples, so skip the listing in that case instead of crashing.
class_acc = cal_class_acc(probability, label, classes)
if class_acc is not None:
    for i, acc in enumerate(class_acc):
        print(f'Class {i} ACC: {acc}')
print('*'*50)

print('*'*50)
# Precision and recall, averaged over classes weighted by class support.
precision = precision_score(label, predicted, average='weighted')
recall = recall_score(label, predicted, average='weighted')
print(f"Precision: {precision}")
print(f"Recall: {recall}")
print('*'*50)

print('*'*50)
# F1 score, averaged over classes weighted by class support.
f1_weighted = f1_score(label, predicted, average='weighted')
print(f"F1 Score (Weighted Average): {f1_weighted}")
print('*'*50)

